// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <wrl.h>
#include <wrl/ftm.h>
//
// Adapted from Windows DXaml comTemplateLibrary.h
//

namespace ctl {

// Returns E_POINTER from the enclosing function when the pointer expression
// `o` is null. The argument is parenthesized so that compound expressions
// (e.g. `cond ? a : b`) are cast and compared as a whole.
// NOTE: expands to a bare `if` statement that already ends in `;` — do not
// use it as the unbraced body of an `if` that is followed by an `else`.
#define ValidateOutputArgPtr(o) \
  if ((LPVOID)(o) == NULL)      \
    return E_POINTER;
// Returns E_OUTOFMEMORY from the enclosing function when `o` compares equal
// to NULL (used in this file right after an allocation attempt). The same
// bare-`if` caveat as ValidateOutputArgPtr applies.
#define CheckAllocationPtr(o) \
  if ((o) == NULL)            \
    return E_OUTOFMEMORY;

// Non-delegating counterpart of IUnknown: declares QueryInterface / AddRef /
// Release under "NonDelegating" names so that an aggregable object can expose
// them separately from its regular (delegating) IUnknown methods.
//
// AggregableComObject (later in this file) reinterpret_casts a pointer to this
// interface to IUnknown*, so the declaration order of the three methods below
// must not be changed.
// NOTE(review): that cast presumably relies on these methods occupying the
// same vtable slots as IUnknown's QueryInterface/AddRef/Release — confirm
// before adding or reordering members.
interface INonDelegatingUnknown {
 public:
  // Queries this object for the interface identified by riid and stores the
  // result through ppvObject.
  virtual HRESULT STDMETHODCALLTYPE NonDelegatingQueryInterface(
      /* [in] */ REFIID riid,
      /* [iid_is][out] */
      __RPC__deref_out void __RPC_FAR *__RPC_FAR *ppvObject) = 0;
  // Reference-count increment; returns a ULONG count.
  virtual ULONG STDMETHODCALLTYPE NonDelegatingAddRef() = 0;
  // Reference-count decrement; returns a ULONG count.
  virtual ULONG STDMETHODCALLTYPE NonDelegatingRelease() = 0;
};

// Non-delegating counterpart of IInspectable: extends INonDelegatingUnknown
// with GetIids / GetRuntimeClassName / GetTrustLevel under "NonDelegating"
// names.
//
// AggregableComObject (later in this file) reinterpret_casts a pointer to this
// interface to IInspectable*, so the declaration order of the methods below
// must not be changed.
// NOTE(review): that cast presumably relies on these methods lining up with
// IInspectable's vtable slots — confirm before adding or reordering members.
interface INonDelegatingInspectable : public INonDelegatingUnknown {
  // Reports the number of interface IDs through iidCount and the array of IDs
  // through iids.
  virtual HRESULT STDMETHODCALLTYPE NonDelegatingGetIids(
      /* [out] */ __RPC__out ULONG * iidCount,
      /* [size_is][size_is][out] */
      __RPC__deref_out_ecount_full_opt(*iidCount) IID * *iids) = 0;

  // Reports the runtime class name through className.
  virtual HRESULT STDMETHODCALLTYPE NonDelegatingGetRuntimeClassName(
      /* [out] */ __RPC__deref_out_opt HSTRING * className) = 0;

  // Reports the trust level through trustLevel.
  virtual HRESULT STDMETHODCALLTYPE NonDelegatingGetTrustLevel(
      /* [out] */ __RPC__out TrustLevel * trustLevel) = 0;
};

template <class TBASE>
class AggregableComObject : public INonDelegatingInspectable, public TBASE {
 public:
  AggregableComObject(__in IInspectable *pOuter) {
    static_assert(
        __is_base_of(::Microsoft::WRL::Details::RuntimeClassBase, TBASE),
        "AggregableComObject can only be used with ::Microsoft::WRL::RuntimeClass types");

    m_pControllingUnknown = pOuter;
  }

  template <class... TArgs>
  AggregableComObject(__in IInspectable *pOuter, TArgs &&... args) : TBASE(std::forward<TArgs>(args)...) {
    static_assert(
        __is_base_of(::Microsoft::WRL::Details::RuntimeClassBase, TBASE),
        "AggregableComObject can only be used with ::Microsoft::WRL::RuntimeClass types");

    m_pControllingUnknown = pOuter;
  }

  // IInspectable (non-delegating) implementation
  IFACEMETHODIMP NonDelegatingQueryInterface(REFIID iid, void **ppInterface) override {
    ValidateOutputArgPtr(ppInterface);
    IUnknown *pInterface = nullptr;

    if (iid == IID_IUnknown) {
      pInterface = reinterpret_cast<IUnknown *>(static_cast<ctl::INonDelegatingUnknown *>(this));
    } else if (iid == IID_IInspectable) {
      pInterface = reinterpret_cast<IInspectable *>(static_cast<ctl::INonDelegatingInspectable *>(this));
    } else {
      return TBASE::QueryInterface(iid, ppInterface);
    }

    *ppInterface = pInterface;
    pInterface->AddRef();
    return S_OK;
  }

  IFACEMETHODIMP_(ULONG) NonDelegatingAddRef() override {
    return TBASE::AddRef();
  }

  IFACEMETHODIMP_(ULONG) NonDelegatingRelease() override {
    return TBASE::Release();
  }

  IFACEMETHODIMP NonDelegatingGetRuntimeClassName(__out HSTRING *pClassName) override {
    return TBASE::GetRuntimeClassName(pClassName);
  }

  IFACEMETHODIMP NonDelegatingGetTrustLevel(__out TrustLevel *trustLevel) override {
    return TBASE::GetTrustLevel(trustLevel);
  }

  IFACEMETHODIMP NonDelegatingGetIids(__out ULONG *iidCount, __deref_out IID **iids) override {
    return TBASE::GetIids(iidCount, iids);
  }

  // IInspectable (delegating) implementation
  IFACEMETHODIMP QueryInterface(REFIID iid, void **ppValue) override {
    return m_pControllingUnknown->QueryInterface(iid, ppValue);
  }

  IFACEMETHODIMP_(ULONG) AddRef() override {
    return m_pControllingUnknown->AddRef();
  }

  IFACEMETHODIMP_(ULONG) Release() override {
    return m_pControllingUnknown->Release();
  }

  IFACEMETHODIMP GetRuntimeClassName(__out HSTRING *pClassName) override {
    return m_pControllingUnknown->GetRuntimeClassName(pClassName);
  }

  IFACEMETHODIMP GetTrustLevel(__out TrustLevel *trustLvl) override {
    return m_pControllingUnknown->GetTrustLevel(trustLvl);
  }

  IFACEMETHODIMP GetIids(__out ULONG *iidCount, __deref_out IID **iids) override {
    return m_pControllingUnknown->GetIids(iidCount, iids);
  }

 public:
  static __checkReturn HRESULT CreateInstance(__in IInspectable *pOuter, __deref_out TBASE **ppNewInstance) {
    ValidateOutputArgPtr(ppNewInstance);
    ::Microsoft::WRL::ComPtr<ctl::AggregableComObject<TBASE>> pNewInstance(
        ::Microsoft::WRL::Make<ctl::AggregableComObject<TBASE>>(pOuter));
    CheckAllocationPtr(pNewInstance);
    *ppNewInstance = static_cast<TBASE *>(pNewInstance.Detach());
    return S_OK;
  }

  template <class... TArgs>
  static __checkReturn HRESULT
  CreateInstance(__in IInspectable *pOuter, __deref_out TBASE **ppNewInstance, TArgs &&... args) {
    ValidateOutputArgPtr(ppNewInstance);
    ::Microsoft::WRL::ComPtr<ctl::AggregableComObject<TBASE>> pNewInstance(
        ::Microsoft::WRL::Make<ctl::AggregableComObject<TBASE>>(pOuter, std::forward<TArgs>(args)...));
    CheckAllocationPtr(pNewInstance);
    *ppNewInstance = static_cast<TBASE *>(pNewInstance.Detach());
    return S_OK;
  }

  template <class T>
  static __checkReturn HRESULT CreateInstance(__in IInspectable *pOuter, __deref_out T **ppNewInstance) {
    ValidateOutputArgPtr(ppNewInstance);
    ::Microsoft::WRL::ComPtr<ctl::AggregableComObject<TBASE>> pNewInstance(
        ::Microsoft::WRL::Make<ctl::AggregableComObject<TBASE>>(pOuter));
    CheckAllocationPtr(pNewInstance);
    return pNewInstance.CopyTo(ppNewInstance);
  }

  template <class T, class... TArgs>
  static __checkReturn HRESULT
  CreateInstance(__in IInspectable *pOuter, __deref_out T **ppNewInstance, TArgs &&... args) {
    ValidateOutputArgPtr(ppNewInstance);
    ::Microsoft::WRL::ComPtr<ctl::AggregableComObject<TBASE>> pNewInstance(
        ::Microsoft::WRL::Make<ctl::AggregableComObject<TBASE>>(pOuter, std::forward<TArgs>(args)...));
    CheckAllocationPtr(pNewInstance);
    return pNewInstance.CopyTo(ppNewInstance);
  }

 private:
  IInspectable *m_pControllingUnknown;
};

// Static creation helpers for TBASE, either as a stand-alone object or as the
// inner object of a COM aggregate (wrapped in ctl::AggregableComObject<TBASE>).
template <class TBASE>
class AggregableComFactory {
 public:
  // Creates a stand-alone, default-constructed TBASE and returns it as
  // IInspectable.
  // Returns E_POINTER if ppInstance is null, E_OUTOFMEMORY if Make yields
  // null, otherwise the HRESULT of ComPtr::CopyTo.
  static HRESULT STDMETHODCALLTYPE ActivateInstance(IInspectable **ppInstance) {
    ValidateOutputArgPtr(ppInstance);
    ::Microsoft::WRL::ComPtr<TBASE> instance(::Microsoft::WRL::Make<TBASE>());
    CheckAllocationPtr(instance);
    return instance.CopyTo(ppInstance);
  }

  // Same as above, forwarding args to TBASE's constructor.
  template <class... TArgs>
  static HRESULT STDMETHODCALLTYPE ActivateInstance(IInspectable **ppInstance, TArgs &&... args) {
    ValidateOutputArgPtr(ppInstance);
    ::Microsoft::WRL::ComPtr<TBASE> instance(::Microsoft::WRL::Make<TBASE>(std::forward<TArgs>(args)...));
    CheckAllocationPtr(instance);
    return instance.CopyTo(ppInstance);
  }

  // Creates a TBASE and returns it as TInterface.
  //   pOuter     - controlling object; non-null requests aggregation.
  //   ppInner    - receives the non-delegating IInspectable of the new object
  //                when aggregating; must be non-null if pOuter is non-null.
  //   ppInstance - receives the requested interface.
  // Returns E_POINTER if ppInstance is null, E_UNEXPECTED if aggregation is
  // requested without ppInner, or the failure HRESULT of the underlying
  // creation / interface query.
  template <class TInterface>
  static HRESULT STDMETHODCALLTYPE
  CreateInstance(IInspectable *pOuter, IInspectable **ppInner, TInterface **ppInstance) {
    ValidateOutputArgPtr(ppInstance);
    if (pOuter != nullptr && ppInner == nullptr)
      return E_UNEXPECTED;

    if (pOuter != nullptr) {
      return CreateAggregatedInstance(pOuter, ppInner, ppInstance);
    } else {
      ::Microsoft::WRL::ComPtr<IInspectable> pInstance;
      // Propagate activation failures instead of calling CopyTo on an empty
      // ComPtr.
      const HRESULT hr = ActivateInstance(&pInstance);
      if (FAILED(hr))
        return hr;
      return pInstance.CopyTo(ppInstance);
    }
  }

  // Same as above, forwarding args to TBASE's constructor.
  template <class TInterface, class... TArgs>
  static HRESULT STDMETHODCALLTYPE
  CreateInstance(IInspectable *pOuter, IInspectable **ppInner, TInterface **ppInstance, TArgs &&... args) {
    ValidateOutputArgPtr(ppInstance);
    if (pOuter != nullptr && ppInner == nullptr)
      return E_UNEXPECTED;

    if (pOuter != nullptr) {
      return CreateAggregatedInstance(pOuter, ppInner, ppInstance, std::forward<TArgs>(args)...);
    } else {
      ::Microsoft::WRL::ComPtr<IInspectable> pInstance;
      // Propagate activation failures instead of calling CopyTo on an empty
      // ComPtr.
      const HRESULT hr = ActivateInstance(&pInstance, std::forward<TArgs>(args)...);
      if (FAILED(hr))
        return hr;
      return pInstance.CopyTo(ppInstance);
    }
  }

 private:
  // Creates an AggregableComObject<TBASE> controlled by pOuter.
  // On success *ppInner holds the object's non-delegating IInspectable and
  // *ppInstance holds the object as TInterface, AddRef'd through that
  // interface pointer (AggregableComObject forwards that AddRef to pOuter).
  // Returns E_POINTER for null out-pointers, E_INVALIDARG when pOuter is null
  // (aggregation requires a controlling object), or the failure HRESULT of
  // AggregableComObject::CreateInstance.
  template <class TInterface>
  static HRESULT CreateAggregatedInstance(IInspectable *pOuter, IInspectable **ppInner, TInterface **ppInstance) {
    ValidateOutputArgPtr(ppInstance);
    ValidateOutputArgPtr(ppInner);
    if (pOuter == nullptr)
      return E_INVALIDARG;

    ctl::AggregableComObject<TBASE> *pInstance = nullptr;
    // Do not touch pInstance unless creation succeeded.
    const HRESULT hr = ctl::AggregableComObject<TBASE>::CreateInstance(pOuter, &pInstance);
    if (FAILED(hr))
      return hr;

    *ppInner = reinterpret_cast<IInspectable *>(static_cast<ctl::INonDelegatingInspectable *>(pInstance));

    *ppInstance = static_cast<TInterface *>(pInstance);
    (*ppInstance)->AddRef();
    return S_OK;
  }

  // Same as above, forwarding args to TBASE's constructor.
  template <class TInterface, class... TArgs>
  static HRESULT
  CreateAggregatedInstance(IInspectable *pOuter, IInspectable **ppInner, TInterface **ppInstance, TArgs &&... args) {
    ValidateOutputArgPtr(ppInstance);
    ValidateOutputArgPtr(ppInner);
    if (pOuter == nullptr)
      return E_INVALIDARG;

    ctl::AggregableComObject<TBASE> *pInstance = nullptr;
    // Do not touch pInstance unless creation succeeded.
    const HRESULT hr =
        ctl::AggregableComObject<TBASE>::CreateInstance(pOuter, &pInstance, std::forward<TArgs>(args)...);
    if (FAILED(hr))
      return hr;

    *ppInner = reinterpret_cast<IInspectable *>(static_cast<ctl::INonDelegatingInspectable *>(pInstance));

    *ppInstance = static_cast<TInterface *>(pInstance);
    (*ppInstance)->AddRef();
    return S_OK;
  }
};

} // namespace ctl
